{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting Regression - Multicomponent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/predicting_regression_multicomponent.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "# Only attempt the install when the COLAB_RELEASE_TAG environment variable is set\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        # Skip the install when chemprop is already importable\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        # Clone the repository, install it from source, then move into its examples/ directory\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from lightning import pytorch as pl\n",
    "from pathlib import Path\n",
    "\n",
    "from chemprop import data, featurizers\n",
    "from chemprop.models import multi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change model input here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parent of the current working directory, used as the base for the paths below\n",
    "chemprop_dir = Path.cwd().parent\n",
    "# Path to the checkpoint file.\n",
    "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder\n",
    "# with a name similar to `checkpoints/epoch=19-step=180.ckpt`.\n",
    "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulticomponentMPNN(\n",
       "  (message_passing): MulticomponentMessagePassing(\n",
       "    (blocks): ModuleList(\n",
       "      (0-1): 2 x BondMessagePassing(\n",
       "        (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "        (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "        (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "        (tau): ReLU()\n",
       "        (V_d_transform): Identity()\n",
       "        (graph_transform): GraphTransform(\n",
       "          (V_transform): Identity()\n",
       "          (E_transform): Identity()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Restore the multicomponent model from the checkpoint file\n",
    "mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)\n",
    "# Bare expression so the notebook displays the loaded model's module structure\n",
    "mcmpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change predict input here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "# Path to your .csv file containing the SMILES strings to make predictions for\n",
    "test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol+mol\" / \"mol+mol.csv\"\n",
    "# Names of the columns containing SMILES strings, one column per component\n",
    "smiles_columns = ['smiles', 'solvent']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load test SMILES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>solvent</th>\n",
       "      <th>peakwavs_max</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>642.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>420.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...</td>\n",
       "      <td>O</td>\n",
       "      <td>544.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>c1ccc2[nH]ccc2c1</td>\n",
       "      <td>O</td>\n",
       "      <td>290.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...</td>\n",
       "      <td>ClC(Cl)Cl</td>\n",
       "      <td>736.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...</td>\n",
       "      <td>C1CCOC1</td>\n",
       "      <td>359.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...</td>\n",
       "      <td>C1CCCCC1</td>\n",
       "      <td>386.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O</td>\n",
       "      <td>CCO</td>\n",
       "      <td>425.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...</td>\n",
       "      <td>c1ccccc1</td>\n",
       "      <td>324.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles    solvent  peakwavs_max\n",
       "0   CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...      ClCCl         642.0\n",
       "1   C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...      ClCCl         420.0\n",
       "2   CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...          O         544.0\n",
       "3                                    c1ccc2[nH]ccc2c1          O         290.0\n",
       "4   CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...  ClC(Cl)Cl         736.0\n",
       "..                                                ...        ...           ...\n",
       "95  COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...    C1CCOC1         359.0\n",
       "96  COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...   C1CCCCC1         386.0\n",
       "97          CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O        CCO         425.0\n",
       "98  Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...   c1ccccc1         324.0\n",
       "99  Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...      ClCCl         391.0\n",
       "\n",
       "[100 rows x 3 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the input CSV into a DataFrame and display it\n",
    "df_test = pd.read_csv(test_path)\n",
    "df_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get SMILES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',\n",
       "        'ClCCl'],\n",
       "       ['C(=C/c1cnccn1)\\\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',\n",
       "        'ClCCl'],\n",
       "       ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',\n",
       "        'O'],\n",
       "       ['c1ccc2[nH]ccc2c1', 'O'],\n",
       "       ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',\n",
       "        'ClC(Cl)Cl']], dtype=object)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2D array of SMILES strings: one row per datapoint, one column per entry in `smiles_columns`\n",
    "smiss = df_test[smiles_columns].values\n",
    "# Preview the first five rows\n",
    "smiss[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get molecule datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one list of datapoints per component, i.e. per column of `smiss`\n",
    "n_components = len(smiles_columns)\n",
    "test_datapointss = [\n",
    "    [data.MoleculeDatapoint.from_smi(smi) for smi in smiss[:, i]] for i in range(n_components)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get molecule datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single featurizer instance is shared by every component's dataset\n",
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
    "# Build one MoleculeDataset per component from the per-component datapoint lists\n",
    "test_dsets = [data.MoleculeDataset(test_datapoints, featurizer) for test_datapoints in test_datapointss]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get multicomponent dataset and data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the per-component datasets into a single multicomponent dataset\n",
    "test_mcdset = data.MulticomponentDataset(test_dsets)\n",
    "# shuffle=False because the predictions are later written back to `df_test` by position\n",
    "test_loader = data.build_dataloader(test_mcdset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up trainer and make predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0:   0%|                            | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brianli/Documents/chemprop/chemprop/nn/message_passing/base.py:263: UserWarning: The operator 'aten::scatter_reduce.two_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
      "  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|████████████████████| 2/2 [00:00<00:00,  4.00it/s]\n"
     ]
    }
   ],
   "source": [
    "# `logger=False` disables logging. Passing `logger=None` does not: Lightning treats it\n",
    "# like the default and attaches a logger, which is not wanted for a prediction-only run.\n",
    "with torch.inference_mode():\n",
    "    trainer = pl.Trainer(\n",
    "        logger=False,\n",
    "        enable_progress_bar=True,\n",
    "        accelerator=\"auto\",\n",
    "        devices=1\n",
    "    )\n",
    "    # `test_preds` holds the per-batch predictions; they are concatenated in the next cell\n",
    "    test_preds = trainer.predict(mcmpnn, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>solvent</th>\n",
       "      <th>peakwavs_max</th>\n",
       "      <th>pred</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>642.0</td>\n",
       "      <td>454.898621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>420.0</td>\n",
       "      <td>453.561584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...</td>\n",
       "      <td>O</td>\n",
       "      <td>544.0</td>\n",
       "      <td>448.694977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>c1ccc2[nH]ccc2c1</td>\n",
       "      <td>O</td>\n",
       "      <td>290.0</td>\n",
       "      <td>448.159760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...</td>\n",
       "      <td>ClC(Cl)Cl</td>\n",
       "      <td>736.0</td>\n",
       "      <td>456.897003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...</td>\n",
       "      <td>C1CCOC1</td>\n",
       "      <td>359.0</td>\n",
       "      <td>454.548584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...</td>\n",
       "      <td>C1CCCCC1</td>\n",
       "      <td>386.0</td>\n",
       "      <td>455.287140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O</td>\n",
       "      <td>CCO</td>\n",
       "      <td>425.0</td>\n",
       "      <td>453.560364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...</td>\n",
       "      <td>c1ccccc1</td>\n",
       "      <td>324.0</td>\n",
       "      <td>454.656891</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>391.0</td>\n",
       "      <td>453.118774</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles    solvent  \\\n",
       "0   CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...      ClCCl   \n",
       "1   C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...      ClCCl   \n",
       "2   CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...          O   \n",
       "3                                    c1ccc2[nH]ccc2c1          O   \n",
       "4   CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...  ClC(Cl)Cl   \n",
       "..                                                ...        ...   \n",
       "95  COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...    C1CCOC1   \n",
       "96  COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...   C1CCCCC1   \n",
       "97          CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O        CCO   \n",
       "98  Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...   c1ccccc1   \n",
       "99  Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...      ClCCl   \n",
       "\n",
       "    peakwavs_max        pred  \n",
       "0          642.0  454.898621  \n",
       "1          420.0  453.561584  \n",
       "2          544.0  448.694977  \n",
       "3          290.0  448.159760  \n",
       "4          736.0  456.897003  \n",
       "..           ...         ...  \n",
       "95         359.0  454.548584  \n",
       "96         386.0  455.287140  \n",
       "97         425.0  453.560364  \n",
       "98         324.0  454.656891  \n",
       "99         391.0  453.118774  \n",
       "\n",
       "[100 rows x 4 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store the concatenated predictions under a new name instead of overwriting `test_preds`,\n",
    "# so re-running this cell does not concatenate an already-concatenated array.\n",
    "test_preds_arr = np.concatenate(test_preds, axis=0)\n",
    "# Add the predictions as a new column next to the input data and display the result\n",
    "df_test['pred'] = test_preds_arr\n",
    "df_test"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
